notebooks/Unit 9 - Model Optimization/onnx.ipynb (377 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ONNX Runtime\n", "\n", "In this notebook, we will show how to use ONNX Runtime to accelerate inference of a model trained in PyTorch. In addition, we will use ONNX to quantize the model to int8 precision to further improve performance by reducing the memory footprint. We will train a simple model on the MNIST dataset and then convert it to ONNX format. We will then use ONNX Runtime to accelerate inference of the model. Finally, we will quantize the model to int8 precision.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup ONNX Runtime\n", "\n", "First, install torch, torchvision, onnx and onnxruntime. Then, import the necessary modules." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install torch torchvision\n", "%pip install onnx onnxruntime" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torchvision import datasets, transforms\n", "import torch.quantization\n", "import pathlib\n", "import numpy as np\n", "import torch.onnx\n", "import onnx\n", "import onnxruntime\n", "from onnxruntime.quantization import quantize_dynamic, quantize_static, CalibrationDataReader, QuantType" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train Model\n", "\n", "We will train a simple CNN model on the MNIST dataset. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "transform=transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,))\n", " ])\n", "\n", "train_dataset = datasets.MNIST('./data', train=True, download=True,transform=transform)\n", "test_dataset = datasets.MNIST('./data', train=False,transform=transform)\n", "\n", "class Net(nn.Module):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.conv1 = nn.Conv2d(in_channels=1, out_channels=12, kernel_size=3)\n", " self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n", " self.fc = nn.Linear(12 * 13 * 13, 10)\n", "\n", " def forward(self, x):\n", " x = x.view(-1, 1, 28, 28) \n", " x = F.relu(self.conv1(x))\n", " x = self.pool(x)\n", " x = x.view(x.size(0), -1) \n", " x = self.fc(x)\n", " output = F.log_softmax(x, dim=1)\n", " return output\n", "\n", "\n", "train_loader = torch.utils.data.DataLoader(train_dataset, 32)\n", "test_loader = torch.utils.data.DataLoader(test_dataset, 32)\n", "\n", "device = \"cpu\"\n", "\n", "epochs = 1\n", "\n", "model = Net().to(device)\n", "optimizer = optim.Adam(model.parameters())\n", "\n", "model.train()\n", "\n", "for epoch in range(1, epochs+1):\n", " for batch_idx, (data, target) in enumerate(train_loader):\n", " data, target = data.to(device), target.to(device)\n", " optimizer.zero_grad()\n", " output = model(data)\n", " loss = F.nll_loss(output, target)\n", " loss.backward()\n", " optimizer.step()\n", " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", " epoch, batch_idx * len(data), len(train_loader.dataset),\n", " 100. 
* batch_idx / len(train_loader), loss.item()))\n", "\n", "MODEL_DIR = pathlib.Path(\"./onnx_models\")\n", "MODEL_DIR.mkdir(exist_ok=True)\n", "torch.save(model.state_dict(), MODEL_DIR / \"original_model.p\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export to ONNX\n", "\n", "After training, export the model to ONNX format.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x, _ = next(iter(train_loader))\n", "torch.onnx.export(model, \n", " x, \n", " MODEL_DIR / \"mnist_model.onnx\", \n", " export_params=True, \n", " opset_version=10, \n", " do_constant_folding=True, \n", " input_names = ['input'], \n", " output_names = ['output'], \n", " dynamic_axes={'input' : {0 : 'batch_size'}, \n", " 'output' : {0 : 'batch_size'}})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run Inference and Test Similarity\n", "\n", "Next, validate the converted model by running inference and comparing the results with the PyTorch model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch_out = model(x)\n", "\n", "onnx_model = onnx.load(MODEL_DIR / \"mnist_model.onnx\")\n", "onnx.checker.check_model(onnx_model)\n", "\n", "ort_session = onnxruntime.InferenceSession(MODEL_DIR / \"mnist_model.onnx\", providers=[\"CPUExecutionProvider\"])\n", "\n", "def to_numpy(tensor):\n", " return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()\n", "\n", "# compute ONNX Runtime output prediction\n", "ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}\n", "ort_outs = ort_session.run(None, ort_inputs)\n", "\n", "# compare ONNX Runtime and PyTorch results\n", "np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)\n", "\n", "print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Quantization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dynamic Quantization\n", "\n", "Dynamic quantization calculates the parameters to be quantized for activations dynamically. These calculations increase the accuracy of the model but increase the cost of inference as well." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!python -m onnxruntime.quantization.preprocess --input {MODEL_DIR / \"mnist_model.onnx\"} --output {MODEL_DIR / \"mnist_model_processed.onnx\"}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_fp32 = MODEL_DIR / \"mnist_model_processed.onnx\"\n", "model_quant = MODEL_DIR / \"mnist_model_quant.onnx\"\n", "quantized_model = quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compare Size\n", "\n", "Let's compare the size of the original model, the quantized model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%ls -lh {MODEL_DIR}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compare Accuracy\n", "\n", "Let's compare the accuracy of the converted onnx model and the quantized model. The accuracy of the quantized model should be close to the original model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test_onnx(model_name, data_loader):\n", " onnx_model = onnx.load(model_name)\n", " onnx.checker.check_model(onnx_model)\n", " ort_session = onnxruntime.InferenceSession(model_name)\n", " test_loss = 0\n", " correct = 0\n", " for data, target in data_loader:\n", " ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(data)}\n", " output = ort_session.run(None, ort_inputs)[0]\n", " output = torch.from_numpy(output)\n", " test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n", " pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n", " correct += pred.eq(target.view_as(pred)).sum().item()\n", "\n", " test_loss /= len(data_loader.dataset)\n", "\n", " return 100. 
* correct / len(data_loader.dataset)\n", "\n", "acc = test_onnx(MODEL_DIR / \"mnist_model.onnx\", test_loader)\n", "print(f\"Accuracy of the original model is {acc}%\")\n", "\n", "qacc = test_onnx(MODEL_DIR / \"mnist_model_quant.onnx\", test_loader)\n", "print(f\"Accuracy of the quantized model is {qacc}%\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Static Quantization\n", "\n", "For the static quantization method, the quantization parameters are computed ahead of time using a calibration dataset. This method is faster than dynamic quantization, but the accuracy can be lower. Hence, a calibration dataset needs to be created using the `CalibrationDataReader` class." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class QuantDR(CalibrationDataReader):\n", " def __init__(self, torch_data_loader, input_name):\n", " self.torch_data_loader = torch_data_loader\n", " self.input_name = input_name\n", " self.datasize = len(torch_data_loader)\n", " self.enum_data = iter(torch_data_loader)\n", "\n", " def to_numpy(self, tensor):\n", " return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()\n", "\n", " def get_next(self):\n", " batch = next(self.enum_data, None)\n", " if batch is not None:\n", " return {self.input_name: self.to_numpy(batch[0])}\n", " else:\n", " return None\n", "\n", " def rewind(self):\n", " self.enum_data = iter(self.torch_data_loader)\n", "\n", "calibration_data = QuantDR(train_loader, ort_session.get_inputs()[0].name)\n", "model__static_quant = MODEL_DIR / \"mnist_model_static_quant.onnx\"\n", "static_quant_model = quantize_static(model_fp32, model__static_quant, calibration_data, weight_type=QuantType.QInt8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compare Size\n", "\n", "Let's compare the size of the original model and the quantized model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%ls -lh {MODEL_DIR}" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "### Compare Accuracy\n", "\n", "Let's compare the accuracy of the converted onnx model and the quantized model. The accuracy of the quantized model should be close to the original model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "static_qacc = test_onnx(model__static_quant, test_loader)\n", "print(f\"Accuracy of the static quantized model is {static_qacc}%\")" ] } ], "metadata": { "kernelspec": { "display_name": "model_optimization", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 2 }